library(text)

# Input files: the combined train/dev split and the held-out test split.
train_path <- "/Users/yuwang/Desktop/PMethods/Binary/train_dev.csv"
test_path <- "/Users/yuwang/Desktop/PMethods/Binary/test.csv"

# Load both splits as data frames (read.csv expects a header row by default).
train <- read.csv(file = train_path)
test <- read.csv(file = test_path)

# Embed the training texts with "roberta-base". Layer outputs are aggregated
# into token vectors by concatenation, and token vectors into one vector per
# text by the mean; token-level embeddings are not kept. The run is timed and
# the result is cached to disk so later steps can reload it.
# NOTE(review): the whole `train` data frame is passed; textEmbed embeds its
# character column(s). Downstream code only uses the `text` column's
# embeddings, so presumably that is the only character column -- confirm.
start_time <- Sys.time()
train_word_embeddings <- textEmbed(
  train,
  model = "roberta-base",
  aggregation_from_layers_to_tokens = "concatenate",
  aggregation_from_tokens_to_texts = "mean",
  keep_token_embeddings = FALSE
)
end_time <- Sys.time()

# Elapsed time as a difftime; R chooses the unit (secs/mins/hours) itself.
duration <- end_time - start_time
# cat() on a bare difftime prints only the number and drops the unit;
# format() keeps it (e.g. "32.43794 mins").
cat("Execution duration for train embedding:", format(duration), "\n")

saveRDS(train_word_embeddings, "/Users/yuwang/Desktop/PMethods/Binary/supp_train_word_embeddings.rds")


# config.json: 100%|██████████| 481/481 [00:00<00:00, 291kB/s]
# tokenizer_config.json: 100%|██████████| 25.0/25.0 [00:00<00:00, 12.1kB/s]
# vocab.json: 100%|██████████| 899k/899k [00:00<00:00, 4.83MB/s]
# merges.txt: 100%|██████████| 456k/456k [00:00<00:00, 2.43MB/s]
# tokenizer.json: 100%|██████████| 1.36M/1.36M [00:01<00:00, 1.12MB/s]
# model.safetensors: 100%|██████████| 499M/499M [09:02<00:00, 920kB/s]  
# Completed layers output for text (variable: 1/1, duration: 17.010602 mins).
# Completed layers aggregation for word_type_embeddings. 
# Completed layers aggregation (variable 1/1, duration: 14.978222 mins).
# > end_time <- Sys.time()
# > # Calculate duration
# > duration <- end_time - start_time
# > # Print the duration
# > cat("Execution duration for train embedding:", duration, "\n")
# Execution duration for train embedding: 32.43794 


# Embed the test texts with exactly the same model and aggregation settings
# used for the training texts, so both sets live in the same feature space.
# The run is timed and the result is cached to disk.
start_time <- Sys.time()
test_word_embeddings <- textEmbed(
  test,
  model = "roberta-base",
  aggregation_from_layers_to_tokens = "concatenate",
  aggregation_from_tokens_to_texts = "mean",
  keep_token_embeddings = FALSE
)
end_time <- Sys.time()

# Elapsed time as a difftime; R chooses the unit (secs/mins/hours) itself.
duration <- end_time - start_time
# Report under the correct label (this is the test set), and keep the unit:
# cat() on a bare difftime prints only the number.
cat("Execution duration for test embedding:", format(duration), "\n")
saveRDS(test_word_embeddings, "/Users/yuwang/Desktop/PMethods/Binary/supp_test_word_embeddings.rds")

# Completed layers output for text (variable: 1/1, duration: 1.668009 mins).
# Completed layers aggregation for word_type_embeddings. 
# Completed layers aggregation (variable 1/1, duration: 2.774775 mins).
# > end_time <- Sys.time()
# > # Calculate duration
# > duration <- end_time - start_time
# > # Print the duration
# > cat("Execution duration for train embedding:", duration, "\n")
# Execution duration for train embedding: 4.520847 
# (Note: the message above says "train", but this timing, in minutes, is for
# the test-set embedding.)

# Reload the cached embeddings so the training/evaluation loop below can be
# rerun without repeating the slow embedding step.
train_word_embeddings <- readRDS("/Users/yuwang/Desktop/PMethods/Binary/supp_train_word_embeddings.rds")
test_word_embeddings <- readRDS("/Users/yuwang/Desktop/PMethods/Binary/supp_test_word_embeddings.rds")

# Learning-curve experiment: for each requested training-set size, fit a
# random-forest classifier on the first `i` training embeddings/labels, time
# the fit, predict the test set, and print accuracy, precision, recall and F1.
# NOTE(review): rows are taken in file order, not sampled; if `train` is
# sorted by label, the small prefixes may contain only one class.
for (i in c(200, 500, 1000, 5000)) {
  # Never index past the end of the training data: `1:i` beyond nrow(train)
  # would feed NA rows / NA labels to the model. Cap the size and say so.
  n_train <- min(i, nrow(train))
  if (n_train < i) {
    warning(
      "Requested training size ", i, " exceeds the ", nrow(train),
      " available rows; using ", n_train, " instead.",
      call. = FALSE
    )
  }
  train_rows <- seq_len(n_train)

  # Record start time
  start_time <- Sys.time()
  model <- textTrain(
    x = train_word_embeddings$texts$text[train_rows, ], # predictors: word embeddings
    y = train$label[train_rows], # criterion: the class label
    force_train_method = "random_forest"
  )
  end_time <- Sys.time()
  duration <- end_time - start_time
  cat("Training size:", n_train, "\n")
  # format() keeps the unit, which cat() drops from a bare difftime.
  cat("Execution duration for training:", format(duration), "\n")

  predictions <- textPredict(
    model,
    word_embeddings = test_word_embeddings$texts
  )
  predicted <- predictions$`text__cv_method="validation_split"pred`

  # Pin both margins of the confusion matrix to the same class levels -- the
  # sorted levels of the test labels, i.e. what table() used for `Actual`
  # before -- so the table is always 2 x 2 even when the model predicts a
  # single class (otherwise conf_matrix[2, 2] is out of bounds).
  # The second level is treated as the positive class.
  class_levels <- levels(factor(test$label))
  stopifnot(length(class_levels) == 2L)
  predicted_f <- factor(predicted, levels = class_levels)
  if (anyNA(predicted_f[!is.na(predicted)])) {
    stop(
      "Predicted labels do not match the test-set label values.",
      call. = FALSE
    )
  }

  # Creating a confusion matrix (rows = predicted, columns = actual)
  conf_matrix <- table(
    Predicted = predicted_f,
    Actual = factor(test$label, levels = class_levels)
  )

  # Extracting True Positives, False Positives, True Negatives, and False Negatives
  TP <- conf_matrix[2, 2]
  TN <- conf_matrix[1, 1]
  FP <- conf_matrix[2, 1]
  FN <- conf_matrix[1, 2]

  # Calculating Accuracy
  accuracy <- (TP + TN) / sum(conf_matrix)

  # Calculating Precision (NaN if the positive class is never predicted)
  precision <- TP / (TP + FP)

  # Calculating Recall
  recall <- TP / (TP + FN)

  # Calculating F1
  f1 <- 2 * precision * recall / (precision + recall)

  # Printing the results
  cat("Accuracy:", accuracy, "\n")
  cat("Precision:", precision, "\n")
  cat("Recall:", recall, "\n")
  cat("F1:", f1, "\n")
}